0%

(NIPS 2015) Spatial transformer networks

Keyword [STN]

Jaderberg M, Simonyan K, Zisserman A. Spatial transformer networks[C]//Advances in neural information processing systems. 2015: 2017-2025.


1. Overview


  • 虽然CNN的效果很好,但是仍然缺乏对数据的空间不变能力,从而限制了计算和参数的效率。因此,论文提出Spatial Transformer Network (STN)。


1.1. STN

  • 在网络中对数据显式地进行空间操作(平移、旋转、缩放、裁剪、扭曲)。由于该操作可微,因此模型能够end to end训练。
  • 根据输入数据,动态生成空间操作参数Θ
  • 网络参数直接通过loss回传进行学习。可直接添加到神经网络模型中,整个训练不需额外的监督信息加入。
  • 空间操作后的数据是与后续特定任务高度相关的。另一方面,变换后的低分辨率数据比原始数据的计算效率更高。
  • 通过对数据进行操作实现不变性,而不是对特征提取器(卷积核)。

1.2. 适用的任务

  • classification
  • co-localization
  • spatial attention

2. Spatial Transformers




STN包含3部分 (Figure 2)

  • localization network.
  • grid generator.
  • sampler.

2.1. Localization Network

  • 输入U(h, w, c)
  • 输出空间变换参数Θ

网络可以是任何形式,如FCN、CNN等。仿射变换Θ的参数为6,投影变换参数为8,以及thin plate spline (TPS). 模型对最后一层的weight矩阵初始化为0,bias初始化为[[1, 0, 0], [0, 1, 0]](仿射变换),即全等变换

2.2. Parameterised Sampling Grid



  • 首先根据采样网格大小(超参数)生成标准网格(t; x,y∈(-1, 1); (h, w, 2)).
  • 利用空间变换参数Θ对其进行变换操作,生成采样网格(s; x,y∈(-1, 1); (h, w, 2)).


2.3. Differentiable Image Sampling

  • 通用的采样公式可写为


  • k为通用采样kernel; x, m, y, n为坐标点。Φ为kernel的参数。
  1. 对于整数采样kernel,公式简化为


  • 取x+0.5下界整数,δ函数为Kronecker delta函数


  1. 对于双线性采样kernel,公式简化为


  • 该公式可导


2.4. Spatial Transformer Networks

  • 由于Θ显式地编码了变换,因此也可将Θ传入后续的网络,而非变换后的特征图(或图片)。
  • 可用STN对特征图进行上采样或下采样。但是,用固定的、小空间支持的采样kernel(双线性kernel)进行下采样会造成影响
  • STN可级联或并行在网络中。




3. Experiments


3.1. Distorted MNIST

  • 数据集distorted方式分为
    • R 旋转,±90°之间。
    • RTS 旋转+缩放+平移
    • P 投影
    • E 弹性形变(破坏性,不可逆)
  • 所有模型都具有相同数量参数,分别使用3类变换操作:仿射变换(Aff)、投影变换(Proj)、薄板样条变换(TPS)。实验发现TPS最有效。


3.2. MNIST Addition

  • 输入两张数字图片(h,w,2),输出数字的和。


3.3. Street View House Numbers

  • 每张图片有1~5个数字。因此,模型采用级联STN,并使用5个独立的softmax分类器,每个分类器包含一个空字符


3.4. Fine-Grained Classification

  • CUB-200-2011数据集,模型采用并行STN结构。




3.5. Co-localization

  • 使用半监督学习来定位图像中的物体。基于正确定位对象A与正确定位对象B之间的距离,比A与随机定位crop小的假设,构造hinge loss


  • T表示crop,e为编码函数,α为margin,实验设置为1。数据集的构建操作为:将2828的数字图片放在8484背景中,并将从训练集中采样得到的16个随机6*6 crop放入背景中。当预测定位与ground-truth的交集大于0.5时,定义为预测正确。



3.6. Higher Dimensionnal Transformer

  • 模型使用3D仿射变换和3D双线性插值操作。





  • 另一种处理方法是:将3D空间投影到2D空间,例如